import math

import dartpy as dart
import numpy as np
import pytest


def read_world(uri: dart.common.Uri):
    """Parse an SDF world and turn on limit enforcement for all of its joints.

    The SDF parser is configured so that its default root joint type is
    ``Fixed`` before the world at ``uri`` is read.

    Args:
        uri: Location of the SDF world file, forwarded unchanged to
            ``dart.io.SdfParser.read_world``.

    Returns:
        The world produced by the parser, after ``set_limit_enforcement(True)``
        has been called on every joint of every skeleton it contains.
    """
    parser_options = dart.io.SdfParser.Options()
    parser_options.m_default_root_joint_type = dart.io.SdfParser.RootJointType.Fixed
    world = dart.io.SdfParser.read_world(uri, parser_options)

    skeletons = (
        world.get_skeleton(index) for index in range(world.get_num_skeletons())
    )
    for skeleton in skeletons:
        for joint_index in range(skeleton.get_num_joints()):
            skeleton.get_joint(joint_index).set_limit_enforcement(True)

    return world


def test_static():
    """Check the joint wrenches of ``model_1`` in the static scenario.

    Loads ``force_torque_test.world``, sets gravity to ``[0, 0, -50]`` and, on
    each of 10 consecutive steps, compares the wrenches that ``joint_01`` and
    ``joint_12`` report toward their parent and child body nodes against fixed
    reference values. The comparisons use exact element-wise equality.
    """

    def translated_frame(parent, name, joint_offset):
        # Build a SimpleFrame whose relative transform keeps only the
        # translation of ``joint_offset`` (rotation left at identity). This is
        # the reference adjustment for the difference of the joint frame
        # conventions between Gazebo and DART.
        pose = dart.Isometry3()
        pose.set_identity()
        pose.set_translation(joint_offset.translation())
        return dart.SimpleFrame(parent, name, pose)

    world = read_world("dart://sample/sdf/test/force_torque_test.world")
    assert world is not None
    assert world.get_num_skeletons() == 1

    # Sanity-check the loaded model before simulating anything.
    skeleton = world.get_skeleton(0)
    assert skeleton is not None
    assert skeleton.get_name() == "model_1"
    assert skeleton.get_num_body_nodes() == 2
    assert skeleton.get_num_joints() == 2

    world.set_gravity([0, 0, -50])

    # After a single step the world time must match one time step.
    world.step()
    elapsed = world.get_time()
    step_size = world.get_time_step()
    assert step_size > 0
    assert elapsed == pytest.approx(step_size)

    # Look up the bodies and joints whose wrenches are inspected below.
    body_1 = skeleton.get_body_node("link_1")
    body_2 = skeleton.get_body_node("link_2")
    jnt_01 = skeleton.get_joint("joint_01")
    jnt_12 = skeleton.get_joint("joint_12")
    for handle in (body_1, body_2, jnt_01, jnt_12):
        assert handle is not None

    expected_parent_01 = [0, 0, 0, 0, 0, 1000]
    expected_parent_12 = [0, 0, 0, 0, 0, 500]

    for _ in range(10):
        world.step()

        # joint_01: the parent-side frame is built on the world frame, the
        # child-side frame on link_1.
        frame_p01 = translated_frame(
            dart.Frame.world(),
            "parent_frame01",
            jnt_01.get_transform_from_parent_body_node(),
        )
        frame_c01 = translated_frame(
            body_1, "child_frame01", jnt_01.get_transform_from_child_body_node()
        )

        wrench_p01 = jnt_01.get_wrench_to_parent_body_node(frame_p01)
        assert (wrench_p01 == expected_parent_01).all()

        # The child-side wrench must be the exact negation of the parent side.
        wrench_c01 = jnt_01.get_wrench_to_child_body_node(frame_c01)
        assert (wrench_c01 == -wrench_p01).all()

        # joint_12: the parent-side frame is built on link_1, the child-side
        # frame on link_2.
        frame_p12 = translated_frame(
            body_1, "parent_frame12", jnt_12.get_transform_from_parent_body_node()
        )
        frame_c12 = translated_frame(
            body_2, "child_frame12", jnt_12.get_transform_from_child_body_node()
        )

        wrench_p12 = jnt_12.get_wrench_to_parent_body_node(frame_p12)
        assert (wrench_p12 == expected_parent_12).all()

        wrench_c12 = jnt_12.get_wrench_to_child_body_node(frame_c12)
        assert (wrench_c12 == -wrench_p12).all()


def test_force_torque_at_joint_limits():
    """Check the joint wrenches of ``model_1`` after the gravity is tilted.

    Loads ``force_torque_test.world``, takes one step under gravity
    ``[0, 0, -50]``, then switches gravity to ``[-30, 10, -50]`` and steps 2000
    times so the dynamics can settle. On each of 5 further steps the wrenches
    of ``joint_01`` and ``joint_12`` toward their parent and child body nodes
    are compared against reference values within the given tolerances.
    """

    def translated_frame(parent, name, joint_offset):
        # Build a SimpleFrame whose relative transform keeps only the
        # translation of ``joint_offset`` (rotation left at identity). This is
        # the reference adjustment for the difference of the joint frame
        # conventions between Gazebo and DART.
        pose = dart.Isometry3()
        pose.set_identity()
        pose.set_translation(joint_offset.translation())
        return dart.SimpleFrame(parent, name, pose)

    world = read_world("dart://sample/sdf/test/force_torque_test.world")
    assert world is not None
    assert world.get_num_skeletons() == 1

    # Sanity-check the loaded model before simulating anything.
    skeleton = world.get_skeleton(0)
    assert skeleton is not None
    assert skeleton.get_name() == "model_1"
    assert skeleton.get_num_body_nodes() == 2
    assert skeleton.get_num_joints() == 2

    world.set_gravity([0, 0, -50])

    # After a single step the world time must match one time step.
    world.step()
    elapsed = world.get_time()
    step_size = world.get_time_step()
    assert step_size > 0
    assert elapsed == pytest.approx(step_size)

    # Look up the bodies and joints whose wrenches are inspected below.
    body_1 = skeleton.get_body_node("link_1")
    body_2 = skeleton.get_body_node("link_2")
    jnt_01 = skeleton.get_joint("joint_01")
    jnt_12 = skeleton.get_joint("joint_12")
    for handle in (body_1, body_2, jnt_01, jnt_12):
        assert handle is not None

    # Tilt gravity so that the top link topples over, then give the dynamics
    # time to stabilize before measuring again.
    world.set_gravity([-30, 10, -50])
    for _ in range(2000):
        world.step()

    for _ in range(5):
        world.step()

        # joint_01: the parent-side frame is built on the world frame, the
        # child-side frame on link_1.
        frame_p01 = translated_frame(
            dart.Frame.world(),
            "parent_frame01",
            jnt_01.get_transform_from_parent_body_node(),
        )
        frame_c01 = translated_frame(
            body_1, "child_frame01", jnt_01.get_transform_from_child_body_node()
        )

        wrench_p01 = jnt_01.get_wrench_to_parent_body_node(frame_p01)
        assert np.allclose(
            wrench_p01, [750, 0, -450, 600, -200, 1000], rtol=0.1, atol=4.5
        )

        wrench_c01 = jnt_01.get_wrench_to_child_body_node(frame_c01)
        assert np.allclose(
            wrench_c01, [-750, -450, 0, -600, 1000, 200], rtol=0.1, atol=4.5
        )

        # joint_12: the parent-side frame is built on link_1, the child-side
        # frame on link_2.
        frame_p12 = translated_frame(
            body_1, "parent_frame12", jnt_12.get_transform_from_parent_body_node()
        )
        frame_c12 = translated_frame(
            body_2, "child_frame12", jnt_12.get_transform_from_child_body_node()
        )

        wrench_p12 = jnt_12.get_wrench_to_parent_body_node(frame_p12)
        assert np.allclose(
            wrench_p12, [250, 150, 0, 300, -500, -100], rtol=0.1, atol=0.1
        )

        wrench_c12 = jnt_12.get_wrench_to_child_body_node(frame_c12)
        assert np.allclose(
            wrench_c12, [-250, -150, 0, -300, 500, 100], rtol=0.1, atol=0.1
        )


def test_force_torque_at_joint_limits_with_external_forces():
    """Check joint wrenches of the ``boxes`` model while joint forces are applied.

    Loads ``force_torque_test2.world`` and verifies its structure (3 body
    nodes, 3 joints, 2 DOFs, a ``WeldJoint`` root). For 4500 steps, a force
    proportional to the position error is applied to ``joint1`` (target ``0``)
    and ``joint2`` (target ``-pi/4``). Afterwards the joint positions and the
    wrenches of both joints toward their parent and child body nodes are
    compared against reference values.
    """

    def translated_frame(parent, name, joint_offset):
        # Build a SimpleFrame whose relative transform keeps only the
        # translation of ``joint_offset`` (rotation left at identity). This is
        # the reference adjustment for the difference of the joint frame
        # conventions between Gazebo and DART.
        pose = dart.Isometry3()
        pose.set_identity()
        pose.set_translation(joint_offset.translation())
        return dart.SimpleFrame(parent, name, pose)

    world = read_world("dart://sample/sdf/test/force_torque_test2.world")
    assert world is not None
    assert world.get_num_skeletons() == 1

    # Sanity-check the loaded model before simulating anything.
    skeleton = world.get_skeleton(0)
    assert skeleton is not None
    assert skeleton.get_name() == "boxes"
    assert skeleton.get_num_body_nodes() == 3
    assert skeleton.get_num_joints() == 3
    assert skeleton.get_num_dofs() == 2

    # The root joint is expected to be a fixed (weld) joint.
    assert skeleton.get_root_joint().get_type() == dart.WeldJoint.get_static_type()

    world.set_gravity([0, 0, -50])

    # After a single step the world time must match one time step.
    world.step()
    elapsed = world.get_time()
    step_size = world.get_time_step()
    assert step_size > 0
    assert elapsed == pytest.approx(step_size)

    # Look up the bodies and joints whose wrenches are inspected below.
    body_1 = skeleton.get_body_node("link1")
    body_2 = skeleton.get_body_node("link2")
    body_3 = skeleton.get_body_node("link3")
    jnt_1 = skeleton.get_joint("joint1")
    jnt_2 = skeleton.get_joint("joint2")
    for handle in (body_1, body_2, body_3, jnt_1, jnt_2):
        assert handle is not None

    # Controller settings: gains and target positions for the two joints.
    gain_1 = 5e4
    gain_2 = 1e4
    goal_1 = 0
    goal_2 = -0.25 * math.pi
    num_steps = 4500

    # Run ``num_steps`` (4500) controlled steps. Both positions are read before
    # either force is written, then the world advances by one step.
    for _ in range(num_steps):
        position_1 = jnt_1.get_position(0)
        position_2 = jnt_2.get_position(0)
        jnt_1.set_force(0, gain_1 * (goal_1 - position_1))
        jnt_2.set_force(0, gain_2 * (goal_2 - position_2))

        world.step()

    assert jnt_1.get_position(0) == pytest.approx(goal_1, abs=1e-1)
    assert jnt_2.get_position(0) == pytest.approx(goal_2, abs=1e-1)

    abs_tol = 2

    # joint1: the parent-side frame is built on link1, the child-side frame on
    # link2.
    frame_p01 = translated_frame(
        body_1, "parent_frame01", jnt_1.get_transform_from_parent_body_node()
    )
    frame_c01 = translated_frame(
        body_2, "child_frame01", jnt_1.get_transform_from_child_body_node()
    )

    wrench_p01 = jnt_1.get_wrench_to_parent_body_node(frame_p01)
    assert np.allclose(wrench_p01, [25, -175, 0, 0, 0, 300], rtol=0.01, atol=abs_tol)

    wrench_c01 = jnt_1.get_wrench_to_child_body_node(frame_c01)
    assert np.allclose(wrench_c01, [-25, 175, 0, 0, 0, -300], rtol=0.01, atol=abs_tol)

    # joint2: the parent-side frame is built on link2, the child-side frame on
    # link3.
    frame_p12 = translated_frame(
        body_2, "parent_frame12", jnt_2.get_transform_from_parent_body_node()
    )
    frame_c12 = translated_frame(
        body_3, "child_frame12", jnt_2.get_transform_from_child_body_node()
    )

    wrench_p12 = jnt_2.get_wrench_to_parent_body_node(frame_p12)
    assert np.allclose(wrench_p12, [25, 0, 0, 0, 0, 50], rtol=0.01, atol=abs_tol)

    wrench_c12 = jnt_2.get_wrench_to_child_body_node(frame_c12)
    assert np.allclose(
        wrench_c12,
        [-17.678, 0, 17.679, -35.355, 0, -35.355],
        rtol=0.01,
        atol=abs_tol,
    )


# Allow running this file directly as a script by handing control to pytest.
if __name__ == "__main__":
    pytest.main()
